import torch
import collections

# Expose the abstract base classes under their legacy ``collections.<Name>``
# aliases (presumably needed by the third-party packages imported below, which
# may still look them up on ``collections`` directly -- TODO confirm).
for _abc_name in ("Iterable", "Mapping", "MutableSet", "MutableMapping"):
    setattr(collections, _abc_name, getattr(collections.abc, _abc_name))
del _abc_name

import tltorch
import math

import tensorly as tly

# Select tensorly's 'pytorch' backend at import time; the layer defined below
# passes torch tensors to tly.base.unfold.
tly.set_backend('pytorch')


class Conv2d_tucker_BUG_adaptive(torch.nn.Conv2d):
    """
    2D convolution whose kernel is stored in Tucker format with adaptive ranks.

    The kernel of shape (out_channels, in_channels, kH, kW) is represented by a core tensor `C` and one
    factor matrix per mode (`Us`). Every buffer is allocated at full size (`rmax`); only the leading
    `rank[i]` entries along mode i are active. `step("K")` enlarges the active basis using the gradients of
    the factors, `step("C")` truncates it again based on the singular values of the core unfoldings.
    """

    def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, groups=1, bias=True,
                 dilation=1, tau=0.1, start_rank_percent=0.3) -> None:
        """
        Initializer for the convolutional low rank layer, extension of the classical Pytorch's convolutional layer.
        INPUTS:
        in_planes: number of input channels (Pytorch's standard)
        out_planes: number of output channels (Pytorch's standard)
        kernel_size : kernel_size for the convolutional filter (Pytorch's standard)
        stride : stride of the filter (Pytorch's standard)
        padding : padding of the convolution (Pytorch's standard)
        groups : forwarded to torch.nn.Conv2d's constructor. NOTE(review): `forward` does not pass it on to
                 tucker_conv and the factor shapes use the full `in_channels`, so values other than 1 are
                 presumably unsupported -- TODO confirm.
        bias  : flag variable for the bias to be included (Pytorch's standard)
        dilation : dilation of the convolution (Pytorch's standard)
        tau : relative tolerance used by `truncate`; trailing singular values whose norm is below
              tau * (norm of all singular values) are cut off
        start_rank_percent : fraction of the out/in channel dimensions used as initial rank. The initial
              channel ranks are at least 3 but never larger than the dimension itself.
        """
        super().__init__(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding,
                         groups=groups, bias=bias, dilation=dilation)

        self.tau = tau
        self.dims = [self.out_channels, self.in_channels] + list(self.kernel_size)
        self.rmax = [int(d) for d in self.dims]
        self.rmin = [min(3, d) for d in self.dims]

        # Channel modes start at a fraction of their size, with at least 3 channels (for rgb images).
        # The rank is clamped to the dimension so that it always matches the true width of the sliced
        # factors (otherwise e.g. a single input channel would yield rank 3 for a dimension of size 1).
        channel_ranks = [min(max(int(d * start_rank_percent), 3), d) for d in self.dims[:2]]
        # Kernel modes always start (and stay, see `truncate`) at full rank.
        self.rank = channel_ranks + self.dims[2:]
        self.r_old = list(self.rank)

        # Core tensor and factor matrices, allocated at maximal rank.
        self.C = torch.nn.Parameter(torch.empty(size=self.rmax), requires_grad=True)
        self.Us = torch.nn.ParameterList(
            [torch.nn.Parameter(torch.empty(size=(d, r)), requires_grad=True) for d, r in
             zip(self.dims, self.rmax)])

        # Snapshot of the factors (written by `copy_U`), read in the basis augmentation step.
        self.U_olds = torch.nn.ParameterList(
            [torch.nn.Parameter(torch.empty(size=(d, r)), requires_grad=False) for d, r in
             zip(self.dims, self.rmax)])

        # Projections new_basis^T @ old_basis, one per mode.
        self.Ms = torch.nn.ParameterList(
            [torch.nn.Parameter(torch.empty(size=(r, r)), requires_grad=False) for r in self.rmax])

        del self.weight  # remove original kernel
        self.reset_tucker_parameters()  # parameter initialization

    @torch.no_grad()
    def reset_tucker_parameters(self):
        """
        (Re-)initializes the core and the factor matrices. The factors are orthonormalized with a QR
        decomposition. The snapshots `U_olds` start as a copy of the factors and the projections `Ms` as
        zeros, so that no buffer is left with uninitialized memory.
        """
        torch.nn.init.kaiming_uniform_(self.C, a=math.sqrt(5))
        for U, U_old, M in zip(self.Us, self.U_olds, self.Ms):
            torch.nn.init.kaiming_uniform_(U, a=math.sqrt(5))

            # Orthonormalize bases
            q, _ = torch.linalg.qr(U.data, mode='reduced')
            U.data = q

            U_old.data.copy_(U.data)
            M.data.zero_()

    def forward(self, input):
        """
        Forward phase for the convolutional layer: the convolution is evaluated directly on the Tucker
        representation (active part of the core and of the factors) via tltorch's tucker_conv.
        INPUTS:
        input : input tensor of the convolution
        OUTPUT:
        result of tltorch.functional.tucker_conv for the current ranks
        """
        C = self.C[:self.rank[0], :self.rank[1], :self.rank[2], :self.rank[3]]
        Us = [U[:, :self.rank[i]] for i, U in enumerate(self.Us)]

        # The bias (if any), stride, padding and dilation of the layer are handed over to tucker_conv.
        result = tltorch.functional.tucker_conv(input, tucker_tensor=tltorch.TuckerTensor(C, Us, rank=self.rank),
                                                bias=self.bias, stride=self.stride, padding=self.padding,
                                                dilation=self.dilation)
        return result

    @torch.no_grad()
    def step(self, dlrt_step="C", lr=0.05):
        """
        Performs one structural update of the layer.
        INPUTS:
        dlrt_step : "K" augments every basis with the gradient of its factor (rank becomes
                    min(2 * rank, rmax)) and projects the core onto the new bases; "C" truncates the ranks
                    (see `truncate`). Any other value is a no-op.
        lr : currently unused, kept for interface compatibility (no explicit gradient step is performed here)
        RAISES:
        RuntimeError if dlrt_step == "K" and a factor matrix has no gradient
        """
        if dlrt_step == "K":
            # Validate before touching any state, so a failure leaves the layer unchanged.
            if any(U.grad is None for U in self.Us):
                raise RuntimeError(
                    "step('K') requires gradients on all factor matrices; "
                    "run a backward pass with the factors' requires_grad enabled first.")

            self.r_old = list(self.rank)
            for i, U in enumerate(self.Us):
                r_old = self.r_old[i]
                r_new = min(2 * r_old, self.rmax[i])  # take new rank as 2*r or max rank

                # Augmented basis [U, grad U], cut to the new rank and orthonormalized.
                aug_basis = torch.cat((U.data[:, :r_old], U.grad[:, :r_old]), dim=1)
                q, _ = torch.linalg.qr(aug_basis[:, :r_new], mode='reduced')
                U.data[:, :r_new] = q
                self.rank[i] = r_new

                # dims: r x r_old. The old basis is read from the snapshot written by `copy_U`.
                self.Ms[i].data[:r_new, :r_old] = U.data[:, :r_new].T @ self.U_olds[i].data[:, :r_old]

            # project C onto new basis
            self.C.data[:self.rank[0], :self.rank[1], :self.rank[2], :self.rank[3]] = torch.einsum(
                'abcd,ia,jb,kc,ld->ijkl',
                self.C.data[:self.r_old[0], :self.r_old[1], :self.r_old[2], :self.r_old[3]],
                self.Ms[0][:self.rank[0], : self.r_old[0]],
                self.Ms[1][:self.rank[1], : self.r_old[1]],
                self.Ms[2][:self.rank[2], : self.r_old[2]],
                self.Ms[3][:self.rank[3], : self.r_old[3]])

        elif dlrt_step == "C":
            # rank truncation
            self.truncate()

    @torch.no_grad()
    def truncate(self):
        """
        Truncates the ranks of the two channel modes. For each of them the mode unfolding of the active
        core is decomposed with an SVD; the new rank is the smallest j such that the norm of the singular
        values from j on is below tau * (norm of all singular values), clamped to [rmin, rmax]. Factors and
        core are then rotated into the basis of the leading left singular vectors.
        """
        r_hat = list(self.rank)
        # The core is only overwritten after the loop, so both unfoldings see the same values.
        core = self.C.data[:r_hat[0], :r_hat[1], :r_hat[2], :r_hat[3]]

        Ps = []
        for i in range(2):  # iterate over nodes, but not over the last two (since their dims is too small)
            MAT_i_C = tly.base.unfold(core, mode=i)
            P, d, _ = torch.linalg.svd(MAT_i_C, full_matrices=False)
            Ps.append(P)

            tol = self.tau * torch.linalg.norm(d)
            r_new = r_hat[i]
            for j in range(r_hat[i]):
                if torch.linalg.norm(d[j:r_hat[i]]) < tol:
                    r_new = j
                    break

            r_new = max(min(r_new, self.rmax[i]), self.rmin[i])
            # The SVD returns only min(r_hat[i], product of the other ranks) left singular vectors; never
            # select more columns than exist (the assignments below would otherwise broadcast or fail).
            self.rank[i] = min(r_new, P.shape[1])  # rank update

            # update U
            self.Us[i].data[:, :self.rank[i]] = self.Us[i].data[:, :r_hat[i]] @ P[:, :self.rank[i]]

        # update Core (only the two channel modes are rotated)
        self.C.data[:self.rank[0], :self.rank[1], :self.rank[2], :self.rank[3]] = torch.einsum(
            'abcd,ia,jb->ijcd', core,
            Ps[0][:, :self.rank[0]].T, Ps[1][:, :self.rank[1]].T)

    @torch.no_grad()
    def copy_U(self):
        """
        Stores a copy of the active part of every factor matrix in `U_olds`.
        """
        for U_old, U, r in zip(self.U_olds, self.Us, self.rank):
            U_old.data[:, :r] = U.data[:, :r].clone()

    @torch.no_grad()
    def get_r_mod_i(self, i):
        """
        Returns the minimum of rank[i] and the product of all other ranks.
        """
        return min(self.rank[i], math.prod([r for j, r in enumerate(self.rank) if j != i]))

    @torch.no_grad()
    def activate_grad_step(self, step="C"):
        """
        Selects which parameters receive gradients: the factors only for step "K", the core only for
        step "C". Any other value leaves the flags untouched.
        """
        if step not in ("K", "C"):
            return
        train_factors = step == "K"
        for u in self.Us:
            u.requires_grad = train_factors
        self.C.requires_grad = not train_factors

    @torch.no_grad()
    def set_grad_zero(self):
        """
        Zeroes the gradients of the factor matrices. Factors without a gradient are skipped. The gradient
        of the core is not touched.
        """
        for u in self.Us:
            if u.grad is not None:
                u.grad.zero_()
